/*
 * Copyright (C) 2014 Tobias Brunner
 * Copyright (C) 2014 Andreas Steffen
 * HSR Hochschule fuer Technik Rapperswil
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 */

#include "bliss_param_set.h"

#include <library.h>

#include <stdio.h>
#include <math.h>

/**
 * One (z1, z2) symbol of a BLISS signature together with the Huffman code
 * word that gets assigned to it.
 */
typedef struct tuple_t {
	/** index of the z1 interval (0 .. number of z1 intervals - 1) */
	int8_t z1;
	/** z2 value, may be negative */
	int8_t z2;
	/** position of this tuple in the generated tuples[] table */
	uint16_t index;
	/** length of the code word in bits */
	uint16_t bits;
	/** code word, stored in the low-order `bits` bits */
	uint32_t code;
} tuple_t;

/**
 * Node of the Huffman tree.  While the tree is being built, the nodes that
 * have not been merged yet are additionally chained via `next`.
 */
typedef struct node_t {
	/** next node in the list of not-yet-merged nodes */
	struct node_t *next;
	/** child reached with a 0 bit */
	struct node_t *l;
	/** child reached with a 1 bit */
	struct node_t *r;
	/** tuple of a leaf node, NULL for inner nodes */
	tuple_t *tuple;
	/** probability of this subtree */
	double p;
	/** height of this subtree (0 for leaves) */
	uint16_t depth;
	/** position of this node in the generated nodes[] table */
	uint16_t index;
} node_t;

/**
 * Debug output: print one node to stderr as "(z1,z2)  p", with the tuple
 * column left blank for inner nodes.
 */
static void print_node(node_t *node)
{
	const tuple_t *leaf = node->tuple;

	if (!leaf)
	{
		fprintf(stderr, "      ");
	}
	else
	{
		fprintf(stderr, "(%1d,%2d)", leaf->z1, leaf->z2);
	}
	fprintf(stderr, "  %18.16f\n", node->p);
}

/**
 * Walk the subtree rooted at node in pre-order and assign code words.
 *
 * Every visited node receives the next running number from *index.  Each
 * leaf tuple receives the code word accumulated along its path (a 0 bit for
 * the left branch, a 1 bit for the right branch) and its length.
 *
 * @param node		root of the subtree
 * @param index		running node counter, incremented once per node
 * @param bits		depth of node, i.e. the length of its code word
 * @param code		code word leading to node
 * @return			sum of p * bits over all tuples in the subtree
 */
static double code_node(node_t *node, int *index, uint8_t bits, uint32_t code)
{
	double expected_bits = 0;
	uint32_t prefix = code << 1;

	node->index = (*index)++;

	if (node->tuple)
	{
		node->tuple->bits = bits;
		node->tuple->code = code;
		expected_bits = node->p * bits;
	}
	if (node->l)
	{
		expected_bits += code_node(node->l, index, bits + 1, prefix);
	}
	if (node->r)
	{
		/* the low bit of prefix is zero, so OR-ing in 1 equals adding 1 */
		expected_bits += code_node(node->r, index, bits + 1, prefix | 1);
	}
	return expected_bits;
}

/**
 * Print the nodes[] table rows for the subtree rooted at node, in pre-order.
 * This is the same order in which code_node() numbers the nodes.
 *
 * Each row holds:
 * - the table index of the "0" child, or BLISS_HUFFMAN_CODE_NO_NODE;
 * - the table index of the "1" child, or BLISS_HUFFMAN_CODE_NO_NODE;
 * - the tuple index for a leaf, or BLISS_HUFFMAN_CODE_NO_TUPLE.
 *
 * The row is followed by a descriptive comment.
 */
static void write_node(node_t *node)
{
	tuple_t *leaf = node->tuple;
	int16_t child0 = BLISS_HUFFMAN_CODE_NO_NODE;
	int16_t child1 = BLISS_HUFFMAN_CODE_NO_NODE;
	int16_t leaf_index = BLISS_HUFFMAN_CODE_NO_TUPLE;

	if (node->l)
	{
		child0 = node->l->index;
	}
	if (node->r)
	{
		child1 = node->r->index;
	}
	if (leaf)
	{
		leaf_index = leaf->index;
	}

	printf("\t{ %3d, %3d, %3d },  /* %3d: ", child0, child1, leaf_index,
		   node->index);
	if (leaf)
	{
		/* pad "bit" with a blank so that the columns stay aligned */
		printf("(%d,%2d) %2u bit%s ", leaf->z1, leaf->z2, leaf->bits,
			   (leaf->bits == 1) ? " " : "s");
	}
	printf("*/\n");

	if (node->l)
	{
		write_node(node->l);
	}
	if (node->r)
	{
		write_node(node->r);
	}
}

/**
 * Print the comment block that opens the generated source file to stdout.
 * It marks the file as generated and not meant to be edited manually.
 */
static void write_header(void)
{
	printf("/*\n");
	printf(" * Copyright (C) 2014 Andreas Steffen\n");
	printf(" * HSR Hochschule fuer Technik Rapperswil\n");
	printf(" *\n");
	printf(" * Optimum Huffman code for BLISS-X signatures\n");
	printf(" *\n");
	printf(" * This file has been automatically generated by the"
		   " bliss_huffman utility\n");
	printf(" * Do not edit manually!\n");
	printf(" */\n\n");
}

/**
 * Assign code words to all tuples and print the generated tables to stdout.
 *
 * The output consists of:
 * - the nodes[] table (pre-order walk of the tree rooted at nodes);
 * - the tuples[] table with each tuple's code word and length;
 * - the expected code length;
 * - the bliss_huffman_code_<bliss_type> definition referencing both tables.
 *
 * @param bliss_type	BLISS parameter set id, used in the emitted symbol name
 * @param n_z1			number of z1 intervals
 * @param n_z2			z2 ranges over 1-n_z2 .. n_z2-1
 * @param nodes			root of the Huffman tree
 * @param tuples		n_z1 * (2*n_z2 - 1) tuples, ordered by z1, then z2
 */
static void write_code_tables(int bliss_type, int n_z1, int n_z2, node_t *nodes,
							  tuple_t **tuples)
{
	int index, i, k, j;
	double code_length;

	printf("#include \"bliss_huffman_code.h\"\n\n");

	printf("static bliss_huffman_code_node_t nodes[] = {\n");
	index = 0;
	code_length = code_node(nodes, &index, 0, 0);
	write_node(nodes);
	printf("};\n\n");

	printf("static bliss_huffman_code_tuple_t tuples[] = {\n");
	index = 0;
	for (i = 0; i < n_z1; i++)
	{
		if (i > 0)
		{
			printf("\n");
		}
		for (k = 1 - n_z2; k < n_z2; k++)
		{
			tuple_t *tuple = tuples[index];

			printf("\t{ %5u, %2u },  /* %3d: (%1d,%2d) ",
						tuple->code, tuple->bits, index, i, k);
			/* Print the code word MSB first.  Iterating over bit positions
			 * avoids shifting by a negative count when bits == 0 (the tree
			 * is a single leaf).  It also avoids overflowing a signed int
			 * for 32-bit code words.  Positions beyond the 32-bit storage
			 * are printed as 0. */
			for (j = tuple->bits - 1; j >= 0; j--)
			{
				printf("%s", (j < 32 && ((tuple->code >> j) & 1)) ? "1" : "0");
			}
			printf(" */\n");
			index++;
		}
	}
	printf("};\n\n");
	printf("/* code_length = %6.4f bits/tuple (%d bits) */\n\n",
			   code_length, (int)(512 * code_length + 1));

	printf("bliss_huffman_code_t bliss_huffman_code_%d = {\n", bliss_type);
	printf("\t.n_z1 = %d,\n", n_z1);
	printf("\t.n_z2 = %d,\n", n_z2);
	printf("\t.tuples = tuples,\n");
	printf("\t.nodes  = nodes\n");
	printf("};\n");
}

/**
 * Free a Huffman (sub)tree.
 *
 * Frees both child subtrees, then the node's tuple (if any), then the node
 * itself.  node must not be NULL.
 */
static void destroy_node(node_t *node)
{
	node_t *children[] = { node->l, node->r };
	size_t n;

	for (n = 0; n < sizeof(children) / sizeof(children[0]); n++)
	{
		if (children[n])
		{
			destroy_node(children[n]);
		}
	}
	free(node->tuple);
	free(node);
}

/**
 * Unlink node from the singly-linked list headed by list.
 *
 * The head element list itself is never removed.  If node was the tail
 * (*last), *last is moved to the node's predecessor.  Nothing happens if
 * node is not part of the list.
 */
static void remove_node(node_t *list, node_t **last, node_t *node)
{
	node_t *pred = list;

	/* advance until pred is the element right before node */
	while (pred->next && pred->next != node)
	{
		pred = pred->next;
	}
	if (!pred->next)
	{
		return;
	}

	pred->next = node->next;
	if (*last == node)
	{
		*last = pred->next ? pred->next : pred;
	}
}

/**
 * Generate a Huffman code for the optimum encoding of BLISS signatures
 *
 * usage: bliss_huffman <bliss type> [<pairs>]
 *
 * Tuple probabilities are derived from a Gaussian with the set's sigma,
 * evaluated with erf().  z1 is split into intervals of width 256 up to
 * B_inf, and z2 into intervals of width 2^d.
 *
 * Tree construction proceeds in three phases:
 * 1. The first <pairs> merges (default 1) may only combine leaf tuples.
 * 2. Next, pairs >> 1 merges are restricted to depth-1 nodes, then half as
 *    many to depth-2 nodes, and so on.
 * 3. Finally, unconstrained Huffman merging builds the rest of the tree.
 *
 * The generated C source is written to stdout.  It contains the code tables
 * and a comment documenting the design probabilities.  The individual merge
 * steps are logged to stderr.
 */
int main(int argc, char *argv[])
{
	const bliss_param_set_t *set;
	int dx, bliss_type, depth = 1, groups, groups_left, pairs = 1;
	/* capacities of the probability and tuple tables declared below */
	int i_max = 9, k_max = 8, index_max = (2*k_max - 1) * i_max;
	int i, i_top, k, k_top;
	uint16_t index;
	double p, p_z1[i_max], p_z2[k_max], x_z1[i_max], x_z2[k_max];
	double t, x, x0, p_sum, entropy = 0, erf_i, erf_k, erf_0 = 0;
	tuple_t *tuple, *tuples[index_max];
	node_t *node, *node_l, *node_r, *nodes = NULL;
	node_t *node_list, *node_last;

	if (argc < 2)
	{
		fprintf(stderr, "usage: bliss_huffman <bliss type> [<pairs>]\n");
		exit(1);
	}
	if (argc > 2)
	{
		pairs = atoi(argv[2]);
	}
	fprintf(stderr, "%d code pairs with constant length\n\n", pairs);
	groups_left = groups = pairs >> 1;

	bliss_type = atoi(argv[1]);
	set = bliss_param_set_get_by_id(bliss_type);
	if (!set)
	{
		fprintf(stderr, "bliss type %d unsupported\n", bliss_type);
		exit(1);
	}

	/* Number of z1 intervals (width 256) and z2 intervals (width 2^d).
	 * Both must fit into the fixed-size arrays above, otherwise the loops
	 * below would write past their end. */
	i_top = (set->B_inf + 255) / 256;
	dx = 1 << set->d;
	k_top = 1 + set->B_inf / dx;
	if (i_top > i_max || k_top > k_max)
	{
		fprintf(stderr, "bliss type %d needs %d x %d intervals, at most "
				"%d x %d supported\n", bliss_type, i_top, k_top, i_max, k_max);
		exit(1);
	}

	write_header();
	printf("/*\n");
	printf(" * Design: sigma = %u\n", set->sigma);
	printf(" *\n");

	t = 1/(sqrt(2) * set->sigma);

	/* Probability distribution for z1 */
	p_sum = 0;
	x = 0;

	for (i = 0; i < i_top; i++)
	{
		x = min(x + 256, set->B_inf);
		erf_i = erf(t*x);
		p_z1[i] = erf_i - erf_0;
		p_sum += p_z1[i];
		erf_0 = erf_i;
		x_z1[i] = x;
	}

	/* Normalize and print the probability distribution for z1 */
	printf(" *   i  p_z1[i]\n");
	x0 = 0;

	for (i = 0; i < i_top; i++)
	{
		p_z1[i] /= p_sum;
		printf(" *  %2d  %18.16f      %4.0f .. %4.0f\n", i, p_z1[i], x0, x_z1[i]);
		x0 = x_z1[i];
	}
	printf(" *\n");

	/* Probability distribution for z2, symmetric around 0: p_z2[k] holds
	 * the probability of z2 = +k (and of z2 = -k) */
	x = (dx >> 1) - 0.5;
	p_sum = 0;

	for (k = 0; k < k_top; k++)
	{
		erf_k = erf(t*x) / 2;
		p_z2[k] = (k == 0) ? 2*erf_k : erf_k - erf_0;
		p_sum +=  (k == 0) ? p_z2[k] : 2*p_z2[k];
		erf_0 = erf_k;
		x_z2[k] = x;
		x += dx;
	}

	/* Normalize the probability distribution for z2 */
	for (k = 0; k < k_top; k++)
	{
		p_z2[k] /= p_sum;
	}

	/* Print the probability distribution for z2 */
	printf(" *   k  p_z2[k]  dx = %d\n", dx);

	for (k = 1 - k_top; k < k_top; k++)
	{
		printf(" *  %2d  %18.16f  ",k, p_z2[abs(k)]);
		if (k < 0)
		{
			printf(" %7.1f ..%7.1f\n", -x_z2[-k], -x_z2[-k-1]);
		}
		else if (k == 0)
		{
			printf(" %7.1f ..%7.1f\n", -x_z2[k], x_z2[k]);
		}
		else
		{
			printf(" %7.1f ..%7.1f\n", x_z2[k-1], x_z2[k]);
		}
	}
	printf(" *\n");

	/* Compute probabilities of tuples (z1, z2) and create one leaf node per
	 * tuple, appended to the list of unmerged nodes */
	INIT(node_list);
	node_last = node_list;
	printf(" *  (i, k)  p\n");
	p_sum =0;
	index = 0;

	for (i = 0; i < i_top; i++)
	{
		for (k = 1 - k_top; k < k_top; k++)
		{
			p = p_z1[i] * p_z2[abs(k)];
			printf(" *  (%1d,%2d)  %18.16f\n", i, k, p);
			p_sum += p;
			/* 0 * log(0) is taken as 0.  Differences of erf() values close
			 * to 1 can round to exactly 0; without this guard the entropy
			 * would become NaN. */
			if (p > 0)
			{
				entropy += -log(p) * p;
			}

			INIT(tuple,
				.z1 = i,
				.z2 = k,
				.index = index,
			);
			tuples[index++] = tuple;

			INIT(node,
				.p = p,
				.tuple = tuple,
			);
			node_last->next = node;
			node_last = node;
		}
		printf(" *\n");
	}
	entropy /= log(2);
	printf(" *  p_sum   %18.16f\n", p_sum);
	printf(" *\n");
	printf(" * entropy = %6.4f bits/tuple (%d bits)\n",
			   entropy, (int)(512 * entropy));
	printf(" */\n\n");

	/* Build Huffman tree: repeatedly merge the two least probable eligible
	 * nodes until a single node (the root) remains in the list */
	while (node_list->next != node_last)
	{
		node_r = node_l = NULL;

		for (node = node_list->next; node; node = node->next)
		{
			if (pairs > 0)
			{
				/* first phase: only combine leaf tuples */
				if (!node->tuple)
				{
					continue;
				}
			}
			else if (groups_left > 0)
			{
				/* second phase: only combine inner nodes of the current depth */
				if (node->tuple || node->depth != depth)
				{
					continue;
				}
			}
			if (node_r == NULL || node->p < node_r->p)
			{
				node_l = node_r;
				node_r = node;
			}
			else if (node_l == NULL || node->p < node_l->p)
			{
				node_l = node;
			}
		}

		if (!node_l)
		{
			/* fewer than two eligible nodes, e.g. because <pairs> exceeds
			 * half the number of tuples */
			fprintf(stderr, "unable to select two nodes to merge, "
					"too many code pairs?\n");
			exit(1);
		}

		INIT(node,
			.l = node_l,
			.r = node_r,
			.p = node_l->p + node_r->p,
			.depth = 1 + max(node_l->depth, node_r->depth),
			.tuple = NULL,
		);
		print_node(node_r);
		print_node(node_l);
		fprintf(stderr, "        %18.16f", node->p);

		remove_node(node_list, &node_last, node_l);
		remove_node(node_list, &node_last, node_r);
		node_last->next = node;
		node_last = node;

		if (pairs > 0)
		{
			pairs--;
		}
		else if (groups > 0)
		{
			if (--groups_left == 0)
			{
				/* all groups at this depth done, continue one level up */
				groups >>= 1;
				groups_left = groups;
				depth++;
			}
		}
		fprintf(stderr, "\n\n");
	}

	nodes = node_list->next;

	write_code_tables(bliss_type, i_top, k_top, nodes, tuples);

	destroy_node(nodes);
	destroy_node(node_list);
	exit(0);
}

